require 'torch'
require 'nn'
require 'cunn'
require 'cutorch'
require 'cudnn'


--- Append one residual stage of the generator to `model`.
-- The stage is a two-branch ConcatTable whose outputs are summed:
--   * main path: k x k SpatialFullConvolution, stride 1, padding `pad`.
--     With pad = (k - 1) / 2 the transposed convolution keeps height and
--     width unchanged, since (w - 1) - 2 * pad + k = w.
--   * shortcut: 1x1 SpatialConvolution + SpatialBatchNormalization that
--     projects n_in planes to n_out planes at the same resolution.
-- The sum is then batch-normalized and passed through ReLU.
-- @param model container to append to (anything with an `add` method)
-- @tparam number n_in input feature planes
-- @tparam number n_out output feature planes
-- @tparam number k kernel width/height of the main-path full convolution
-- @tparam number pad padding of the main-path full convolution
-- @return model, for chaining
local function add_g_residual_stage(model, n_in, n_out, k, pad)
    model:add(nn.ConcatTable()
       :add(nn.Sequential()
          :add(nn.SpatialFullConvolution(n_in, n_out, k, k, 1, 1, pad, pad, 0, 0))
       )
       :add(nn.Sequential()
          :add(nn.SpatialConvolution(n_in, n_out, 1, 1, 1, 1, 0, 0))
          :add(nn.SpatialBatchNormalization(n_out))
       )
    )
    model:add(nn.CAddTable())
    model:add(nn.SpatialBatchNormalization(n_out))
    model:add(nn.ReLU())
    return model
end

--- Build the generator network G.
-- Four residual stages map n_channels -> 64 -> 128 -> 64 -> n_channels
-- feature planes. Every stage uses stride 1 with size-preserving padding,
-- so the output has the same height and width as the input.
-- @tparam[opt=1] number n_channels planes in the input and output images
-- @treturn nn.Sequential the untrained generator
function create_G(n_channels)
    n_channels = n_channels or 1

    local model = nn.Sequential()
    add_g_residual_stage(model, n_channels, 64, 5, 2)
    add_g_residual_stage(model, 64, 128, 5, 2)
    add_g_residual_stage(model, 128, 64, 3, 1)
    -- NOTE(review): the last stage ends in BatchNorm + ReLU, so the generated
    -- output is always >= 0. Confirm that matches the target data range
    -- (Tanh or no activation are common alternatives for a generator output).
    add_g_residual_stage(model, 64, n_channels, 3, 1)

    return model
end

--- Append one downsampling residual stage of the discriminator to `model`.
-- The stage is a two-branch ConcatTable whose outputs are summed:
--   * main path: k x k SpatialConvolution (stride 1, padding `pad`, which
--     keeps the size when pad = (k - 1) / 2) followed by 2x2 max pooling,
--     giving floor(w / 2);
--   * shortcut: 1x1 SpatialConvolution with stride 2 (giving ceil(w / 2))
--     plus SpatialBatchNormalization, projecting n_in -> n_out planes.
-- The two branches agree in size only when the incoming width/height is even.
-- The sum is then batch-normalized and passed through ReLU.
-- @param model container to append to (anything with an `add` method)
-- @tparam number n_in input feature planes
-- @tparam number n_out output feature planes
-- @tparam number k kernel width/height of the main-path convolution
-- @tparam number pad padding of the main-path convolution
-- @return model, for chaining
local function add_d_residual_stage(model, n_in, n_out, k, pad)
    model:add(nn.ConcatTable()
       :add(nn.Sequential()
          :add(nn.SpatialConvolution(n_in, n_out, k, k, 1, 1, pad, pad))
          :add(nn.SpatialMaxPooling(2, 2))
       )
       :add(nn.Sequential()
          :add(nn.SpatialConvolution(n_in, n_out, 1, 1, 2, 2, 0, 0))
          :add(nn.SpatialBatchNormalization(n_out))
       )
    )
    model:add(nn.CAddTable())
    model:add(nn.SpatialBatchNormalization(n_out))
    model:add(nn.ReLU())
    return model
end

--- Build the discriminator network D.
-- Four downsampling residual stages (n_channels -> 64 -> 128 -> 256 -> 128
-- planes) each halve height and width. The result is flattened and passed
-- through a small MLP ending in a Sigmoid, so the output lies in (0, 1).
-- @tparam[opt=1] number n_channels planes in the input image
-- @tparam[opt=256] number input_size side length of the (square) input image;
--   must be a multiple of 16. The default reproduces the 32768-wide
--   flattened feature vector.
-- @treturn nn.Sequential the untrained discriminator
function create_D(n_channels, input_size)
    n_channels = n_channels or 1
    input_size = input_size or 256
    -- Four halvings must stay exact: an odd intermediate size would make the
    -- pooled main path (floor) and the stride-2 shortcut (ceil) disagree.
    assert(input_size % 16 == 0, 'create_D: input_size must be a multiple of 16')

    local model = nn.Sequential()
    add_d_residual_stage(model, n_channels, 64, 3, 1)
    add_d_residual_stage(model, 64, 128, 3, 1)
    add_d_residual_stage(model, 128, 256, 5, 2)
    add_d_residual_stage(model, 256, 128, 5, 2)

    -- 128 planes of (input_size / 16)^2 positions; 32768 for a 256x256 input.
    local side = math.floor(input_size / 16)
    local n_flat = 128 * side * side
    model:add(nn.Reshape(n_flat))
    model:add(nn.Linear(n_flat, 64))
    -- Without a nonlinearity here the two Linear layers compose into a single
    -- affine map, making the 64-unit hidden layer redundant.
    model:add(nn.ReLU())
    model:add(nn.Linear(64, 1))
    model:add(nn.Sigmoid())

    return model
end

-- Construct and serialize each network in turn: the generator is built and
-- written to disk before the discriminator is built.
for _, entry in ipairs({ { 'G0.t7', create_G }, { 'D0.t7', create_D } }) do
    local path, build = entry[1], entry[2]
    torch.save(path, build())
end